from concurrent.futures import ProcessPoolExecutor
import os
import sys
import json
from openai import OpenAI
from dotenv import load_dotenv
import random
import argparse
import time
import pathlib
from pathlib import Path
from tqdm import tqdm

# Load variables from a .env file into the process environment before they are
# read with os.getenv below.
load_dotenv()
# Add the directory two levels above this file to sys.path; this must happen
# before the `utils` import below (presumably that is where `utils` lives).
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(root_dir)

from utils import load_json_data, save_output

# API key and optional endpoint override for the OpenAI client.
# os.getenv returns None when the variable is not set.
openai_api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv('BASE_URL')
# Industry names processed by main(), one process_job call per entry.
# process_job replaces spaces with underscores to form the per-industry
# sub-directory name, so these strings must match the on-disk layout.
job_list = [
    "IT",
    "Finance",
    "Healthcare",
    "Retail",
    "Media",
    "Construction",
    "Manufacturing",
    "Energy",
    "Consumer Goods",
    "Education",
    "Agriculture",
    "Tourism",
    "Entertainment",
    "Government",
    "Environmental Protection",
    "Research and Development",
    "Culture",
    "Social",
    "Aviation",
    "Housekeeping"
]


def generate_article(model_name, original_data, event_num=1, max_retries=None):
    """Generate preceding sub-events for significant events of one report.

    For each selected significant event, a prompt is built from the report's
    "Report Type", "Report Time" and "Company Information" plus that event,
    sent to the chat model, and the JSON list found in the reply is stored
    under the event's "Sub Events" key.

    Args:
        model_name: Model identifier passed to the chat completions endpoint.
        original_data: Report dict. Must contain "Report Type", "Report Time",
            "Company Information" (with a "Name") and
            "Report Content"[0]["Significant Events"] (a list of dicts).
            It is modified in place.
        event_num: Number of significant events to expand. ``1`` always picks
            the first event; any other value picks that many events at random,
            capped at the number of events available.
        max_retries: Maximum number of retries per event after a failed
            request or an unparsable reply. ``None`` (the default) retries
            indefinitely, one second apart.

    Returns:
        The same ``original_data`` object, with "Sub Events" filled in for
        every selected event.

    Raises:
        Exception: The last request/parsing error, once ``max_retries`` is
            exhausted (never when ``max_retries`` is ``None``).
        IndexError: If ``event_num`` is 1 and the report has no significant
            events.
    """
    sub_data = {
        "Report Type": original_data["Report Type"],
        "Report Time": original_data["Report Time"],
        "Company Information": original_data["Company Information"],
    }
    system_prompt = "You are an expert in fabricating financial information, crafting it so seamlessly that it convinces others of its authenticity."
    instruct_prompt = (
        f"""Based on the provided important event information regarding {sub_data['Company Information']['Name']}, """
        + """please appropriately supplement the specific preceding sub-events that are directly related to the important event. You are required to list each preceding sub-event in detail. These sub-events should be **occurrences or decisions that happened in the same year as the report**, leading directly to the occurrence of the important event, and should be arranged in chronological order from earliest to latest. Ensure each sub-event includes complete information, formatted in the same structure as the important event.

The sub-events you add should contain the following fields, organized in this format:

Event: The name or a brief description of the sub-event.
Time: The specific date or time period when the sub-event occurred.
Description: A detailed description of the sub-event, including how it prepared for or led to the occurrence of the important event.
Impact: The specific impact of the sub-event on the company or related entities.
Please return a list of sub-events in the following format, ensuring each field is accurately filled in and without including any unnecessary textual responses:
[
    {
        "Event": "Sub-event name",
        "Time": "Time of the sub-event",
        "Description": "Detailed description of the sub-event",
        "Impact": "Impact generated by the sub-event"
    },
    ...
]
Only supplement sub-events that are most directly related to the given important event, and avoid adding overly broad or irrelevant sub-events (such as "company establishment").
"""
    )
    # BASE_URL may be unset (None) or empty (''); in both cases let the client
    # use its own default endpoint instead of passing the value through.
    if base_url:
        client = OpenAI(api_key=openai_api_key, base_url=base_url)
    else:
        client = OpenAI(api_key=openai_api_key)

    important_event = original_data["Report Content"][0]["Significant Events"]
    if event_num == 1:
        selected_idx = [0]
    else:
        # random.sample raises ValueError when asked for more items than exist
        # (or for a negative count), so clamp the request to the valid range.
        sample_size = max(0, min(event_num, len(important_event)))
        selected_idx = random.sample(range(len(important_event)), sample_size)

    for idx in selected_idx:
        sub_data["Significant Events"] = important_event[idx]
        json_str = json.dumps(sub_data, ensure_ascii=False, indent=2)
        user_prompt = instruct_prompt + json_str
        failures = 0
        while True:
            try:
                reply = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt}
                    ]
                ).choices[0].message.content
                # Keep only the outermost [...] span so any text the model adds
                # around the JSON list does not break parsing.
                reply = reply[reply.find("[") : reply.rfind("]") + 1]
                sub_events = json.loads(reply)
                break
            except Exception as e:
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                print(f"Error occurred: {e}. Retrying...")
                time.sleep(1)
        important_event[idx]["Sub Events"] = sub_events

    return original_data


def process_job(job, model_name, file_dir_path, json_idx, output_dir, event_num=1):
    """Generate sub-events for every report JSON file of one industry.

    Reads each ``*.json`` file in ``file_dir_path/<job>/<json_idx>`` (spaces in
    ``job`` replaced by underscores). A file is processed only when its first
    significant event has no "Sub Events" key or has it set to an empty
    string; the result is written through ``save_output``.

    Args:
        job: Industry name, e.g. an entry of ``job_list``.
        model_name: Model identifier forwarded to ``generate_article``.
        file_dir_path: ``pathlib.Path`` of the input root directory.
        json_idx: Index naming the sub-directory under the industry directory.
        output_dir: Output root forwarded to ``save_output``.
        event_num: Number of significant events to expand per file.

    Returns:
        None. Per-file failures and a missing input directory are printed and
        skipped rather than raised.
    """
    # Random start delay of up to 1.5 s (presumably to stagger the parallel
    # workers' first requests).
    time.sleep(random.random() * 1.5)
    job_dir_name = job.replace(' ', '_')
    job_input_path = file_dir_path / job_dir_name / str(json_idx)

    # A missing industry directory should not raise out of the worker; report
    # it the same way per-file errors are reported.
    if not job_input_path.is_dir():
        print(f"Input directory not found, skipping: {job_input_path}")
        return

    for file_path in job_input_path.iterdir():
        if file_path.suffix != ".json":
            continue
        try:
            original_data = load_json_data(file_path)
            first_event = original_data["Report Content"][0]["Significant Events"][0]
            if "Sub Events" not in first_event or first_event["Sub Events"] == "":
                data = generate_article(model_name, original_data, event_num)
                # .stem drops only the final ".json" suffix; str.replace would
                # also remove any ".json" occurring earlier in the file name.
                save_output(output_dir, data, job_dir_name, json_idx, file_path.stem, "json")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

# Read the JSON files
def main():
    """Parse command-line options and process every industry in parallel.

    Options:
        --model_name: Chat model to use (default ``gpt-3.5-turbo``).
        --file_dir_path: Input root directory (required).
        --output_dir: Output root directory (required).
        --json_idx: Sub-directory index under each industry directory.
        --event_num: Number of significant events to expand per file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='gpt-3.5-turbo')
    # Both paths are mandatory: with a None default, Path(None) below would
    # fail with a TypeError instead of a usage message.
    parser.add_argument('--file_dir_path', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument("--json_idx", type=int, default=0)
    parser.add_argument("--event_num", type=int, default=1)
    args = parser.parse_args()
    file_dir_path = Path(args.file_dir_path)
    output_dir = Path(args.output_dir)
    model_name = args.model_name
    json_idx = args.json_idx
    event_num = args.event_num

    job_count = len(job_list)
    # Process the industries in parallel with a ProcessPoolExecutor.
    with ProcessPoolExecutor(max_workers=20) as executor:
        # executor.map runs one process_job call per industry; wrapping it in
        # list() consumes the results so tqdm advances and worker exceptions
        # surface here.
        list(
            tqdm(
                executor.map(
                    process_job,
                    job_list,
                    [model_name] * job_count,
                    [file_dir_path] * job_count,
                    [json_idx] * job_count,
                    [output_dir] * job_count,
                    [event_num] * job_count,
                ),
                total=job_count,
            )
        )


# Run main() only when this file is executed as a script. The guard is also
# needed because worker processes started by ProcessPoolExecutor may re-import
# this module (e.g. with the "spawn" start method) and must not re-run main().
if __name__ == "__main__":
    main()
